{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e504f9ed",
   "metadata": {},
   "source": [
    "# Working with Meshes\n",
    "\n",
    "This tutorial shows how to streamline work with Kaolin operations using the `SurfaceMesh` container class. We will cover import/export, batching strategies, managing mesh data, rendering and visualization.\n",
    "\n",
    "Note that material support in `SurfaceMesh` is currently limited; fuller support is on the roadmap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d8db10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "\n",
    "import kaolin as kal\n",
    "from kaolin.rep import SurfaceMesh\n",
    "\n",
    "from tutorial_common import COMMON_DATA_DIR\n",
    "\n",
    "def sample_mesh_path(fname):\n",
    "    return os.path.join(COMMON_DATA_DIR, 'meshes', fname)\n",
    "\n",
    "def print_tensor(t, **kwargs):\n",
    "    print(kal.utils.testing.tensor_info(t, **kwargs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16c66902",
   "metadata": {},
   "source": [
    "## Understanding the SurfaceMesh Container\n",
    "\n",
    "`SurfaceMesh` can store a single mesh or a batch of meshes, using one of three batching strategies:\n",
    "   * `NONE` - a single mesh, not batched\n",
    "   * `FIXED` - a batch of meshes with fixed topology (faces are fixed)\n",
    "   * `LIST` - a list of variable topology meshes\n",
    "\n",
    "Automatic conversion between these batching strategies makes it easy to connect meshes to various Kaolin operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7ea7171c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected SurfaceMesh contents for batching strategy FIXED\n",
      "            vertices: (torch.FloatTensor) of shape ['B', 'V', 3]\n",
      "               faces: (torch.IntTensor)   of shape ['F', 'FSz']\n",
      "       face_vertices: (torch.FloatTensor) of shape ['B', 'F', 'FSz', 3]\n",
      "             normals: (torch.FloatTensor) of shape ['B', 'VN', 3]\n",
      "    face_normals_idx: (torch.IntTensor)   of shape ['B', 'F', 'FSz']\n",
      "        face_normals: (torch.FloatTensor) of shape ['B', 'F', 'FSz', 3]\n",
      "                 uvs: (torch.FloatTensor) of shape ['B', 'U', 2]\n",
      "        face_uvs_idx: (torch.IntTensor)   of shape ['B', 'F', 'FSz']\n",
      "            face_uvs: (torch.FloatTensor) of shape ['B', 'F', 'FSz', 2]\n",
      "      vertex_normals: (torch.FloatTensor) of shape ['B', 'V', 3]\n",
      "     vertex_tangents: (torch.FloatTensor) of shape ['B', 'V', 3]\n",
      "material_assignments: (torch.IntTensor)   of shape ['B', 'F']\n",
      "           materials: non-tensor attribute\n",
      "\n",
      "Key: B - batch size, F - number of faces, FSz - face size, V - number of vertices,\n",
      "     VN - number of normals, U - number of UVs\n"
     ]
    }
   ],
   "source": [
    "# To get a sense for what the mesh can contain for different batching strategies, run:\n",
    "\n",
    "print(SurfaceMesh.attribute_info_string(SurfaceMesh.Batching.FIXED))\n",
    "print('\\nKey: B - batch size, F - number of faces, FSz - face size, V - number of vertices,'\n",
    "      '\\n     VN - number of normals, U - number of UVs')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf6648b8",
   "metadata": {},
   "source": [
    "## Constructor and Auto-computable Attributes\n",
    "\n",
    "A `SurfaceMesh` can be constructed from torch tensors with names, types and sizes as described above. Only `faces` and `vertices` are required, both of which are allowed to contain zero elements, and **many attributes can be computed automatically**. \n",
    "\n",
    "Important settings passed to the constructor:\n",
    "* `unset_attributes_return_none` (default: `True`) - set this to `False` to raise an error when accessing mesh attributes that are missing\n",
    "* `allow_auto_compute` (default: `True`) - set this to `False` to disable computation of attributes such as `face_uvs` and `vertex_normals`\n",
    "* `strict_checks` (default: `True`) - set this to `False` to allow setting attributes to unexpected shapes\n",
    "\n",
    "You can also set `mesh.unset_attributes_return_none` or `mesh.allow_auto_compute` later to change mesh behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8de1ea19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vertices: [10, 3] (torch.float32)[cpu]  \n",
      "faces: [5, 3] (torch.int64)[cpu]  \n",
      "face_vertices: None\n",
      "face_vertices (auto-computed): [5, 3, 3] (torch.float32)[cpu]  \n"
     ]
    }
   ],
   "source": [
    "# Let's construct a simple unbatched mesh\n",
    "V, F, Fsz = 10, 5, 3\n",
    "mesh = kal.rep.SurfaceMesh(vertices=torch.rand((V, 3)).float(),\n",
    "                           faces=torch.randint(0, V, (F, Fsz)).long(),\n",
    "                           allow_auto_compute=False)  # disable auto-compute for now\n",
    "print_tensor(mesh.vertices, name='vertices')\n",
    "print_tensor(mesh.faces, name='faces')\n",
    "print_tensor(mesh.face_vertices, name='face_vertices')\n",
    "\n",
    "# Now let's enable auto-compute\n",
    "mesh.allow_auto_compute = True\n",
    "print_tensor(mesh.face_vertices, name='face_vertices (auto-computed)')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44055abc",
   "metadata": {},
   "source": [
    "Batched meshes can also be instantiated by passing batched inputs to the constructor, for example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "59a0b773",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instantiated mesh with batching FIXED and length 3\n",
      "Instantiated mesh with batching LIST and length 2\n"
     ]
    }
   ],
   "source": [
    "# FIXED: inputs are batched tensors with fixed faces across batches\n",
    "B, VN = 3, 20\n",
    "mesh_fixed = kal.rep.SurfaceMesh(vertices=torch.rand((B, V, 3)).float(),\n",
    "                                 faces=torch.randint(0, V, (F, Fsz)).long(),\n",
    "                                 normals=torch.rand((B, VN, 3)).float(),\n",
    "                                 face_normals_idx=torch.randint(0, VN, (B, F, Fsz)))\n",
    "print(f'Instantiated mesh with batching {mesh_fixed.batching} and length {len(mesh_fixed)}')\n",
    "\n",
    "# LIST: all inputs are lists of equal length\n",
    "V2, F2 = 12, 20\n",
    "mesh_list = kal.rep.SurfaceMesh(\n",
    "    vertices=[torch.rand((V, 3)).float(), torch.rand((V2, 3)).float()],\n",
    "    faces=[torch.randint(0, V, (F, Fsz)).long(), torch.randint(0, V2, (F2, Fsz)).long()])\n",
    "print(f'Instantiated mesh with batching {mesh_list.batching} and length {len(mesh_list)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96d7d2be",
   "metadata": {},
   "source": [
    "## Inspecting SurfaceMesh Objects\n",
    "\n",
    "Working with many batched mesh attributes can be confusing, and details really matter. `SurfaceMesh` provides multiple ways to inspect its contents. These print statements also make clear which attributes can be auto-computed and how."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb432078",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Mesh with batching NONE and length 1\n",
      "\n",
      "Attributes ['vertices', 'faces', 'face_vertices']\n",
      "\n",
      "Are face_normals set? False\n",
      "\n",
      "Are face_normals auto-computable? True\n",
      "\n",
      "Attributes (after accessing face_normals) ['vertices', 'faces', 'face_vertices', 'face_normals']\n",
      "\n",
      "Face normals        face_normals: [5, 3, 3] (torch.float32)[cpu]  \n",
      "\n",
      "\n",
      "Does the mesh have expected shapes? True\n",
      "SurfaceMesh object with batching strategy NONE\n",
      "            vertices: [10, 3] (torch.float32)[cpu]  \n",
      "               faces: [5, 3] (torch.int64)[cpu]  \n",
      "       face_vertices: [5, 3, 3] (torch.float32)[cpu]  \n",
      "        face_normals: [5, 3, 3] (torch.float32)[cpu]  \n",
      "            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n"
     ]
    }
   ],
   "source": [
    "# Get batching strategy and batch size (length)\n",
    "print(f'\\nMesh with batching {mesh.batching} and length {len(mesh)}')\n",
    "\n",
    "# Get currently set attributes\n",
    "print(f'\\nAttributes {mesh.get_attributes(only_tensors=True)}')\n",
    "\n",
    "# Check if an attribute is set without causing the mesh to auto-compute it\n",
    "print(f'\\nAre face_normals set? {mesh.has_attribute(\"face_normals\")}')\n",
    "\n",
    "# Check if the attribute can likely be auto-computed, without actually computing it\n",
    "print(f'\\nAre face_normals auto-computable? {mesh.probably_can_compute_attribute(\"face_normals\")}')\n",
    "\n",
    "# Let's access face_normals and cause them to be computed\n",
    "mesh.face_normals\n",
    "print(f'\\nAttributes (after accessing face_normals) {mesh.get_attributes(only_tensors=True)}')\n",
    "\n",
    "# Check that face_normals are now set\n",
    "print(f'\\nFace normals{mesh.describe_attribute(\"face_normals\")}\\n')\n",
    "\n",
    "# Check if mesh tensor sizes follow expected conventions\n",
    "print(f'\\nDoes the mesh have expected shapes? {mesh.check_sanity()}')\n",
    "\n",
    "# Print mesh contents (and computable attributes)\n",
    "print(mesh)\n",
    "\n",
    "# We can also convert mesh to string with more details and tensor stats\n",
    "# print(f'\\nDetailed string of {mesh.to_string(detailed=True, print_stats=True)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10b7b49f",
   "metadata": {},
   "source": [
    "## Explicit API\n",
    "\n",
    "In addition to the default `SurfaceMesh` API, which computes attributes on access and caches them automatically, this class also supports an alternative, more verbose API that makes these actions explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d81f1dfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deleting face_vertices\n",
      "Deleting face_normals\n",
      "\n",
      "Mesh attributes after deletion: ['vertices', 'faces']\n",
      "\n",
      "Get face_normals without computing: None\n",
      "\n",
      "Computed face_normals shape is torch.Size([5, 3, 3])\n",
      "\n",
      "Did mesh cache computed face_normals (and face_vertices required to compute them)?\n",
      "False, False\n",
      "\n",
      "Did mesh cache computed face_normals (and face_vertices required to compute them)?\n",
      "True, True\n"
     ]
    }
   ],
   "source": [
    "# Let's delete attributes we just computed\n",
    "mesh.face_vertices = None\n",
    "mesh.face_normals = None\n",
    "\n",
    "# Check attributes were removed\n",
    "print(f'\\nMesh attributes after deletion: {mesh.get_attributes(only_tensors=True)}')\n",
    "\n",
    "# Get attribute without any auto-compute magic\n",
    "print(f'\\nGet face_normals without computing: {mesh.get_attribute(\"face_normals\")}')\n",
    "\n",
    "# Compute attribute, but don't cache\n",
    "face_normals = mesh.get_or_compute_attribute('face_normals', should_cache=False)\n",
    "print(f'\\nComputed face_normals shape is {face_normals.shape}')\n",
    "\n",
    "# Verify attributes were not cached\n",
    "print('\\nDid mesh cache computed face_normals (and face_vertices required to compute them)?')\n",
    "print(f'{mesh.has_attribute(\"face_normals\")}, {mesh.has_attribute(\"face_vertices\")}')\n",
    "\n",
    "# Compute and cache\n",
    "face_normals = mesh.get_or_compute_attribute('face_normals', should_cache=True)\n",
    "\n",
    "print('\\nDid mesh cache computed face_normals (and face_vertices required to compute them)?')\n",
    "print(f'{mesh.has_attribute(\"face_normals\")}, {mesh.has_attribute(\"face_vertices\")}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "787ebbdc",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Importing Data\n",
    "\n",
    "Since version 0.14.0, the Kaolin `obj` and `usd` importers return a `SurfaceMesh`, which is nearly backward-compatible with the previous `named_tuple` return type while providing mutability and convenient data management.\n",
    "\n",
    "**Porting from earlier versions:** When porting from kaolin<=0.13.0, note that the `obj` importer now uses `face_normals_idx` (previously `face_normals`) for the face-vertex indices into normals, and `normals` (previously `vertex_normals`) for the normals array, which may or may not have the same number of elements as `vertices`. In addition, `materials` are now imported in name-sorted order, and `material_order` has been replaced with a `material_assignments` tensor of shape `(num_faces,)`, whose integer values give the index of the material assigned to each face."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f316587e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Mesh imported from obj: SurfaceMesh object with batching strategy NONE\n",
      "            vertices: [42, 3] (torch.float32)[cpu]  \n",
      "               faces: [80, 3] (torch.int64)[cpu]  \n",
      "             normals: [80, 3] (torch.float32)[cpu]  \n",
      "    face_normals_idx: [80, 3] (torch.int64)[cpu]  \n",
      "                 uvs: [63, 2] (torch.float32)[cpu]  \n",
      "        face_uvs_idx: [80, 3] (torch.int64)[cpu]  \n",
      "material_assignments: [80] (torch.int16)[cpu]  \n",
      "           materials: list of length 1\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "        face_normals: if possible, computed on access from: (normals, face_normals_idx) or (vertices, faces)\n",
      "            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "\n",
      "Mesh imported from usd: SurfaceMesh object with batching strategy NONE\n",
      "            vertices: [42, 3] (torch.float32)[cpu]  \n",
      "               faces: [80, 3] (torch.int64)[cpu]  \n",
      "        face_normals: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                 uvs: [240, 2] (torch.float32)[cpu]  \n",
      "        face_uvs_idx: [80, 3] (torch.int64)[cpu]  \n",
      "material_assignments: [80] (torch.int16)[cpu]  \n",
      "           materials: list of length 1\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "dict_keys(['batching', 'allow_auto_compute', 'unset_attributes_return_none', 'materials', 'vertices', 'face_normals', 'uvs', 'faces', 'face_uvs_idx', 'material_assignments'])\n"
     ]
    }
   ],
   "source": [
    "import_args = {'with_materials' : True, 'with_normals' : True}\n",
    "\n",
    "# Let's import a single mesh from OBJ\n",
    "mesh_obj = kal.io.obj.import_mesh(sample_mesh_path('ico_flat.obj'), **import_args)\n",
    "\n",
    "# Let's import the same mesh from its USD version\n",
    "mesh_usd = kal.io.usd.import_mesh(sample_mesh_path('ico_flat.usda'), **import_args)\n",
    "\n",
    "# Let's inspect contents of both meshes (notice consistent naming of attributes)\n",
    "print(f'\\nMesh imported from obj: {mesh_obj}')\n",
    "print(f'\\nMesh imported from usd: {mesh_usd}')\n",
    "\n",
    "# Note: if you prefer to work with raw values, SurfaceMesh is convertible to dict\n",
    "mesh_dict = mesh_usd.as_dict()\n",
    "print(mesh_dict.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdfec845",
   "metadata": {},
   "source": [
    "Although these objects are geometrically the same, you will notice that USD stores UVs and normals differently from OBJ, resulting in different imported arrays. Despite these differences, the UVs and normals assigned to faces (auto-computed where needed) are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5733dc8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Are face_uvs same? True\n",
      "\n",
      "Are face_normals same? True\n"
     ]
    }
   ],
   "source": [
    "print(f'\\nAre face_uvs same? {torch.allclose(mesh_obj.face_uvs, mesh_usd.face_uvs, atol=1e-4)}')\n",
    "print(f'\\nAre face_normals same? {torch.allclose(mesh_obj.face_normals, mesh_usd.face_normals, atol=1e-4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c53e057",
   "metadata": {},
   "source": [
    "## Working with Batches\n",
    "\n",
    "`SurfaceMesh` objects can be converted between batching strategies whenever the conversion is possible (for example, a list of meshes with variable topology cannot be converted to `Batching.FIXED`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f5476b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SurfaceMesh object with batching strategy FIXED\n",
      "            vertices: [1, 42, 3] (torch.float32)[cpu]  \n",
      "               faces: [80, 3] (torch.int64)[cpu]  \n",
      "             normals: [1, 80, 3] (torch.float32)[cpu]  \n",
      "    face_normals_idx: [1, 80, 3] (torch.int64)[cpu]  \n",
      "        face_normals: [1, 80, 3, 3] (torch.float32)[cpu]  \n",
      "                 uvs: [1, 63, 2] (torch.float32)[cpu]  \n",
      "        face_uvs_idx: [1, 80, 3] (torch.int64)[cpu]  \n",
      "            face_uvs: [1, 80, 3, 2] (torch.float32)[cpu]  \n",
      "material_assignments: [1, 80] (torch.int16)[cpu]  \n",
      "           materials: [\n",
      "                      0: list of length 1\n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "\n",
      "SurfaceMesh object with batching strategy LIST\n",
      "            vertices: [\n",
      "                      0: [42, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "               faces: [\n",
      "                      0: [80, 3] (torch.int64)[cpu]  \n",
      "                      ]\n",
      "        face_normals: [\n",
      "                      0: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "                 uvs: [\n",
      "                      0: [240, 2] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "        face_uvs_idx: [\n",
      "                      0: [80, 3] (torch.int64)[cpu]  \n",
      "                      ]\n",
      "            face_uvs: [\n",
      "                      0: [80, 3, 2] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "material_assignments: [\n",
      "                      0: [80] (torch.int16)[cpu]  \n",
      "                      ]\n",
      "           materials: [\n",
      "                      0: list of length 1\n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n"
     ]
    }
   ],
   "source": [
    "# Convert unbatched mesh to most commonly used FIXED batching\n",
    "mesh_obj.to_batched()  # Shortcut for mesh_obj.set_batching(SurfaceMesh.Batching.FIXED)\n",
    "print(mesh_obj)\n",
    "\n",
    "# Convert mesh to list batching\n",
    "mesh_usd.set_batching(SurfaceMesh.Batching.LIST)\n",
    "print(f'\\n{mesh_usd}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6c2a73c",
   "metadata": {},
   "source": [
    "We can also concatenate meshes of any batching strategy, with the output using `FIXED` (if `fixed_topology` is set) or `LIST` batching. An error is raised if concatenation is not possible for `vertices` or `faces`; other attributes are handled where possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9cdad2dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cannot cat uvs arrays of given shapes; trying to concatenate face_uvs instead, due to: stack expects each tensor to be equal size, but got [63, 2] at entry 0 and [240, 2] at entry 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SurfaceMesh object with batching strategy FIXED\n",
      "            vertices: [2, 42, 3] (torch.float32)[cpu]  \n",
      "               faces: [80, 3] (torch.int64)[cpu]  \n",
      "        face_normals: [2, 80, 3, 3] (torch.float32)[cpu]  \n",
      "            face_uvs: [2, 80, 3, 2] (torch.float32)[cpu]  \n",
      "material_assignments: [2, 80] (torch.int16)[cpu]  \n",
      "           materials: [\n",
      "                      0: list of length 1\n",
      "                      1: list of length 1\n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n"
     ]
    }
   ],
   "source": [
    "mesh = SurfaceMesh.cat([mesh_obj, mesh_usd], fixed_topology=True)\n",
    "\n",
    "# Notice that the concatenated mesh:\n",
    "# 1. does not have uvs, as those could not be stacked, but face_uvs were computed and stacked instead.\n",
    "# 2. does not have normals, as only the first mesh had them, but face_normals were computed and stacked.\n",
    "print(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0091b207",
   "metadata": {},
   "source": [
    "With `fixed_topology=False`, it is also possible to concatenate meshes of variable topology into a list representation. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "541462ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SurfaceMesh object with batching strategy LIST\n",
      "            vertices: [\n",
      "                      0: [42, 3] (torch.float32)[cpu]  \n",
      "                      1: [42, 3] (torch.float32)[cpu]  \n",
      "                      2: [482, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "               faces: [\n",
      "                      0: [80, 3] (torch.int64)[cpu]  \n",
      "                      1: [80, 3] (torch.int64)[cpu]  \n",
      "                      2: [960, 3] (torch.int64)[cpu]  \n",
      "                      ]\n",
      "        face_normals: [\n",
      "                      0: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                      1: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                      2: [960, 3, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "material_assignments: [\n",
      "                      0: [80] (torch.int16)[cpu]  \n",
      "                      1: [80] (torch.int16)[cpu]  \n",
      "                      2: [960] (torch.int16)[cpu]  \n",
      "                      ]\n",
      "           materials: [\n",
      "                      0: list of length 1\n",
      "                      1: list of length 1\n",
      "                      2: list of length 2\n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "\n",
      "Note that auto-compute is still supported, e.g. after access:\n",
      "      vertex_normals: [\n",
      "                      0: [42, 3] (torch.float32)[cpu]  \n",
      "                      1: [42, 3] (torch.float32)[cpu]  \n",
      "                      2: [482, 3] (torch.float32)[cpu]  \n",
      "                      ]\n"
     ]
    }
   ],
   "source": [
    "tmp = SurfaceMesh.cat([mesh, kal.io.usd.import_mesh(sample_mesh_path('pizza.usda'), **import_args)],\n",
    "                      fixed_topology=False)\n",
    "print(tmp)\n",
    "tmp.vertex_normals\n",
    "print(f'\\nNote that auto-compute is still supported, e.g. after access:')\n",
    "print(f'{tmp.describe_attribute(\"vertex_normals\")}')\n",
    "\n",
    "del tmp  # not needed later"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b7855ae",
   "metadata": {},
   "source": [
    "## Convenience Methods and Mutability\n",
    "\n",
    "Now let's see a few useful capabilities of `SurfaceMesh`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "640359ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SurfaceMesh object with batching strategy FIXED\n",
      "            vertices: [2, 42, 3] (torch.float32)[cuda:0]  \n",
      "               faces: [80, 3] (torch.int64)[cuda:0]  \n",
      "        face_normals: [2, 80, 3, 3] (torch.float32)[cuda:0]  \n",
      "            face_uvs: [2, 80, 3, 2] (torch.float32)[cuda:0]  \n",
      "      vertex_normals: [2, 42, 3] (torch.float32)[cpu]  \n",
      "material_assignments: [2, 80] (torch.int16)[cuda:0]  \n",
      "           materials: [\n",
      "                      0: list of length 1\n",
      "                      1: list of length 1\n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "            vertices: [2, 42, 3] (torch.float32)[cuda:0]  - [min -1.0000, max 1.0000, mean -0.0000] \n",
      "            vertices: [2, 42, 3] (torch.float32)[cuda:0]  - [min -0.5000, max 0.5000, mean -0.0000] \n",
      "dict_keys(['vertices', 'faces', 'face_normals', 'face_uvs', 'vertex_normals', 'material_assignments'])\n"
     ]
    }
   ],
   "source": [
    "# Recall that mesh contains two fixed topology meshes\n",
    "\n",
    "# Let's move it to cuda (you can also specify particular attributes to move)\n",
    "mesh = mesh.cuda()\n",
    "\n",
    "# Let's say we actually don't need vertex_normals on the GPU\n",
    "mesh = mesh.cpu(attributes=['vertex_normals'])\n",
    "print(mesh)\n",
    "\n",
    "# We can also directly set mesh attributes, for example:\n",
    "print(mesh.describe_attribute('vertices', print_stats=True))\n",
    "mesh.vertices = kal.ops.pointcloud.center_points(mesh.vertices, normalize=True)\n",
    "print(mesh.describe_attribute('vertices', print_stats=True))\n",
    "\n",
    "# Mesh also supports copy and deepcopy\n",
    "mesh_copy = copy.copy(mesh)\n",
    "mesh_copy = copy.deepcopy(mesh)\n",
    "\n",
    "# Finally, mesh can be converted to a simple dict\n",
    "mesh_dict = mesh.as_dict(only_tensors=True)\n",
    "print(mesh_dict.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdd71c5b",
   "metadata": {},
   "source": [
    "## Optimization and Gradients\n",
    "\n",
    "It is also possible to optimize mesh attributes through auto-computed attributes. However, take care to set `requires_grad` on the input tensor before the auto-computed attribute is cached. Attributes derived from a tensor that requires gradients are then recomputed on every access instead of being cached, allowing gradients to flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9b053b52",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Has face_vertices? False\n",
      "computed face_vertices: [2, 80, 3, 3] (torch.float32)[cuda:0]  \n",
      "Were face_vertices cached? False\n",
      "\n",
      "Sample loss 1.7655433416366577\n"
     ]
    }
   ],
   "source": [
    "# Let's try to optimize vertices\n",
    "mesh.vertices.requires_grad = True\n",
    "\n",
    "# Check that mesh does not cache face_vertices\n",
    "print(f'Has face_vertices? {mesh.has_attribute(\"face_vertices\")}')\n",
    "\n",
    "# Check that we can actually compute them\n",
    "face_vertices = mesh.face_vertices\n",
    "print_tensor(face_vertices, name='computed face_vertices')\n",
    "\n",
    "# However, because mesh.vertices.requires_grad, this value is not cached\n",
    "print(f'Were face_vertices cached? {mesh.has_attribute(\"face_vertices\")}')\n",
    "\n",
    "# Now we can use mesh.face_vertices in a loss function, while optimizing mesh.vertices, e.g.:\n",
    "sample_pt_cloud = torch.randn((2, 100, 3), dtype=mesh.vertices.dtype, device=mesh.vertices.device)\n",
    "sample_loss = kal.metrics.trianglemesh.point_to_mesh_distance(sample_pt_cloud, mesh.face_vertices)[0].mean()\n",
    "print(f'\\nSample loss {sample_loss}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67097c5d",
   "metadata": {},
   "source": [
    "## Exporting\n",
    "\n",
    "Automatic conversion to `LIST` batching also makes it easy to export a batch of meshes to a USD file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bf0591c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exporting to USD: 100%|█████████████████████████████████████████████| 2/2 [00:00<00:00, 190.03mesh/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SurfaceMesh object with batching strategy LIST\n",
      "            vertices: [\n",
      "                      0: [42, 3] (torch.float32)[cpu]  \n",
      "                      1: [42, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "               faces: [\n",
      "                      0: [80, 3] (torch.int64)[cpu]  \n",
      "                      1: [80, 3] (torch.int64)[cpu]  \n",
      "                      ]\n",
      "        face_normals: [\n",
      "                      0: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                      1: [80, 3, 3] (torch.float32)[cpu]  \n",
      "                      ]\n",
      "       face_vertices: if possible, computed on access from: (faces, vertices)\n",
      "            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)\n",
      "      vertex_normals: if possible, computed on access from: (faces, face_normals)\n",
      "     vertex_tangents: if possible, computed on access from: (faces, vertices, face_uvs)\n",
      "True\n",
      "True\n",
      "True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/kaolin/kaolin/io/usd/mesh.py:371: UserWarning: Some child prims for /World/Meshes/mesh_0 are missing uvs; skipping importing uvs.\n",
      "  warnings.warn(f'Some child prims for {scene_path} are missing {k}; skipping importing {k}.', UserWarning)\n",
      "/kaolin/kaolin/io/usd/mesh.py:371: UserWarning: Some child prims for /World/Meshes/mesh_0 are missing face_uvs_idx; skipping importing face_uvs_idx.\n",
      "  warnings.warn(f'Some child prims for {scene_path} are missing {k}; skipping importing {k}.', UserWarning)\n",
      "/kaolin/kaolin/io/usd/mesh.py:371: UserWarning: Some child prims for /World/Meshes/mesh_1 are missing uvs; skipping importing uvs.\n",
      "  warnings.warn(f'Some child prims for {scene_path} are missing {k}; skipping importing {k}.', UserWarning)\n",
      "/kaolin/kaolin/io/usd/mesh.py:371: UserWarning: Some child prims for /World/Meshes/mesh_1 are missing face_uvs_idx; skipping importing face_uvs_idx.\n",
      "  warnings.warn(f'Some child prims for {scene_path} are missing {k}; skipping importing {k}.', UserWarning)\n"
     ]
    }
   ],
   "source": [
    "mesh = mesh.set_batching(SurfaceMesh.Batching.LIST)\n",
    "\n",
    "# Note: you can only run this once due to USD caching; restart Kernel to rerun cell without errors\n",
    "kal.io.usd.export_meshes('/tmp/out.usd', vertices=mesh.vertices, faces=mesh.faces, face_normals=mesh.face_normals)\n",
    "\n",
    "# Verify we can read back the same meshes we exported\n",
    "imported_meshes = SurfaceMesh.cat(\n",
    "    kal.io.usd.import_meshes('/tmp/out.usd', with_normals=True), fixed_topology=False)\n",
    "mesh = mesh.cpu()\n",
    "print(imported_meshes)\n",
    "print(kal.utils.testing.contained_torch_equal(mesh.vertices, imported_meshes.vertices, approximate=True))\n",
    "print(kal.utils.testing.contained_torch_equal(mesh.faces, imported_meshes.faces))\n",
    "print(kal.utils.testing.contained_torch_equal(mesh.face_normals, imported_meshes.face_normals, approximate=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
